!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routines needed for kpoint calculation
!> \par History
!>       2014.07 created [JGH]
!>       2014.11 unified k-point and gamma-point code [Ole Schuett]
!> \author JGH
! **************************************************************************************************
MODULE kpoint_methods
   USE atomic_kind_types,               ONLY: get_atomic_kind
   USE cell_types,                      ONLY: cell_type,&
                                              real_to_scaled
   USE cp_blacs_env,                    ONLY: cp_blacs_env_create,&
                                              cp_blacs_env_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_copy, dbcsr_create, dbcsr_deallocate_matrix, dbcsr_distribution_get, &
        dbcsr_distribution_type, dbcsr_get_block_p, dbcsr_get_info, dbcsr_iterator_blocks_left, &
        dbcsr_iterator_next_block, dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, &
        dbcsr_p_type, dbcsr_set, dbcsr_type, dbcsr_type_antisymmetric, dbcsr_type_no_symmetry, &
        dbcsr_type_symmetric
   USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
   USE cp_dbcsr_operations,             ONLY: copy_fm_to_dbcsr
   USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale
   USE cp_fm_pool_types,                ONLY: cp_fm_pool_p_type,&
                                              fm_pool_create_fm,&
                                              fm_pool_give_back_fm
   USE cp_fm_struct,                    ONLY: cp_fm_struct_type
   USE cp_fm_types,                     ONLY: &
        copy_info_type, cp_fm_cleanup_copy_general, cp_fm_create, cp_fm_finish_copy_general, &
        cp_fm_get_info, cp_fm_release, cp_fm_start_copy_general, cp_fm_to_fm, cp_fm_type
   USE cp_log_handling,                 ONLY: cp_logger_get_default_io_unit
   USE cryssym,                         ONLY: apply_rotation_coord,&
                                              crys_sym_gen,&
                                              csym_type,&
                                              kpoint_gen,&
                                              print_crys_symmetry,&
                                              print_kp_symmetry,&
                                              release_csym_type
   USE fermi_utils,                     ONLY: fermikp,&
                                              fermikp2
   USE input_constants,                 ONLY: smear_fermi_dirac
   USE kinds,                           ONLY: dp
   USE kpoint_types,                    ONLY: get_kpoint_info,&
                                              kpoint_env_create,&
                                              kpoint_env_p_type,&
                                              kpoint_env_type,&
                                              kpoint_sym_create,&
                                              kpoint_sym_type,&
                                              kpoint_type
   USE mathconstants,                   ONLY: twopi
   USE memory_utilities,                ONLY: reallocate
   USE message_passing,                 ONLY: mp_cart_type,&
                                              mp_para_env_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE particle_types,                  ONLY: particle_type
   USE qs_matrix_pools,                 ONLY: mpools_create,&
                                              mpools_get,&
                                              mpools_rebuild_fm_pools,&
                                              qs_matrix_pools_type
   USE qs_mo_types,                     ONLY: allocate_mo_set,&
                                              get_mo_set,&
                                              init_mo_set,&
                                              mo_set_type,&
                                              set_mo_set
   USE qs_neighbor_list_types,          ONLY: get_iterator_info,&
                                              get_neighbor_list_set_p,&
                                              neighbor_list_iterate,&
                                              neighbor_list_iterator_create,&
                                              neighbor_list_iterator_p_type,&
                                              neighbor_list_iterator_release,&
                                              neighbor_list_set_p_type
   USE scf_control_types,               ONLY: smear_type
   USE util,                            ONLY: get_limit
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'kpoint_methods'

   PUBLIC :: kpoint_initialize, kpoint_env_initialize, kpoint_initialize_mos, kpoint_initialize_mo_set
   PUBLIC :: kpoint_init_cell_index, kpoint_set_mo_occupation
   PUBLIC :: kpoint_density_matrices, kpoint_density_transform
   PUBLIC :: rskp_transform

! **************************************************************************************************

CONTAINS

! **************************************************************************************************
!> \brief Generate the kpoints and initialize the kpoint environment
!> \param kpoint       The kpoint environment
!> \param particle_set Particle types and coordinates
!> \param cell         Computational cell information
! **************************************************************************************************
   SUBROUTINE kpoint_initialize(kpoint, particle_set, cell)

      TYPE(kpoint_type), POINTER                         :: kpoint
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'kpoint_initialize'

      INTEGER                                            :: handle, i, ik, iounit, ir, is, natom, &
                                                            nr, ns
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atype
      LOGICAL                                            :: spez
      REAL(KIND=dp)                                      :: wsum
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coord, scoord
      TYPE(csym_type)                                    :: crys_sym
      TYPE(kpoint_sym_type), POINTER                     :: kpsym

      CALL timeset(routineN, handle)

      CPASSERT(ASSOCIATED(kpoint))

      SELECT CASE (kpoint%kp_scheme)
      CASE ("NONE")
         ! do nothing
      CASE ("GAMMA")
         kpoint%nkp = 1
         ALLOCATE (kpoint%xkp(3, 1), kpoint%wkp(1))
         kpoint%xkp(1:3, 1) = 0.0_dp
         kpoint%wkp(1) = 1.0_dp
         ALLOCATE (kpoint%kp_sym(1))
         NULLIFY (kpoint%kp_sym(1)%kpoint_sym)
         CALL kpoint_sym_create(kpoint%kp_sym(1)%kpoint_sym)
      CASE ("MONKHORST-PACK", "MACDONALD")

         IF (.NOT. kpoint%symmetry) THEN
            ! we set up a random molecule to avoid any possible symmetry
            natom = 10
            ALLOCATE (coord(3, natom), scoord(3, natom), atype(natom))
            DO i = 1, natom
               atype(i) = i
               coord(1, i) = SIN(i*0.12345_dp)
               coord(2, i) = COS(i*0.23456_dp)
               coord(3, i) = SIN(i*0.34567_dp)
               CALL real_to_scaled(scoord(1:3, i), coord(1:3, i), cell)
            END DO
         ELSE
            natom = SIZE(particle_set)
            ALLOCATE (scoord(3, natom), atype(natom))
            DO i = 1, natom
               CALL get_atomic_kind(atomic_kind=particle_set(i)%atomic_kind, kind_number=atype(i))
               CALL real_to_scaled(scoord(1:3, i), particle_set(i)%r(1:3), cell)
            END DO
         END IF
         IF (kpoint%verbose) THEN
            iounit = cp_logger_get_default_io_unit()
         ELSE
            iounit = -1
         END IF

         CALL crys_sym_gen(crys_sym, scoord, atype, cell%hmat, delta=kpoint%eps_geo, iounit=iounit)
         CALL kpoint_gen(crys_sym, kpoint%nkp_grid, symm=kpoint%symmetry, shift=kpoint%kp_shift, &
                         full_grid=kpoint%full_grid)
         kpoint%nkp = crys_sym%nkpoint
         ALLOCATE (kpoint%xkp(3, kpoint%nkp), kpoint%wkp(kpoint%nkp))
         wsum = SUM(crys_sym%wkpoint)
         DO ik = 1, kpoint%nkp
            kpoint%xkp(1:3, ik) = crys_sym%xkpoint(1:3, ik)
            kpoint%wkp(ik) = crys_sym%wkpoint(ik)/wsum
         END DO

         ! print output
         IF (kpoint%symmetry) CALL print_crys_symmetry(crys_sym)
         IF (kpoint%symmetry) CALL print_kp_symmetry(crys_sym)

         ! transfer symmetry information
         ALLOCATE (kpoint%kp_sym(kpoint%nkp))
         DO ik = 1, kpoint%nkp
            NULLIFY (kpoint%kp_sym(ik)%kpoint_sym)
            CALL kpoint_sym_create(kpoint%kp_sym(ik)%kpoint_sym)
            kpsym => kpoint%kp_sym(ik)%kpoint_sym
            IF (crys_sym%symlib .AND. .NOT. crys_sym%fullgrid .AND. crys_sym%istriz == 1) THEN
               ! set up the symmetrization information
               kpsym%nwght = NINT(crys_sym%wkpoint(ik))
               ns = kpsym%nwght
               ! to be done correctly
               IF (ns > 1) THEN
                  kpsym%apply_symmetry = .TRUE.
                  natom = SIZE(particle_set)
                  ALLOCATE (kpsym%rot(3, 3, ns))
                  ALLOCATE (kpsym%xkp(3, ns))
                  ALLOCATE (kpsym%f0(natom, ns))
                  nr = 0
                  DO is = 1, SIZE(crys_sym%kplink, 2)
                     IF (crys_sym%kplink(2, is) == ik) THEN
                        nr = nr + 1
                        ir = crys_sym%kpop(is)
                        kpsym%rot(1:3, 1:3, nr) = crys_sym%rotations(1:3, 1:3, ir)
                        kpsym%xkp(1:3, nr) = crys_sym%kpmesh(1:3, is)
                        CALL apply_rotation_coord(kpsym%f0(1:natom, nr), crys_sym, ir)
                     END IF
                  END DO
               END IF
            END IF
         END DO

         CALL release_csym_type(crys_sym)
         DEALLOCATE (scoord, atype)

      CASE ("GENERAL")
         ! default: no symmetry settings
         ALLOCATE (kpoint%kp_sym(kpoint%nkp))
         DO i = 1, kpoint%nkp
            NULLIFY (kpoint%kp_sym(i)%kpoint_sym)
            CALL kpoint_sym_create(kpoint%kp_sym(i)%kpoint_sym)
         END DO
      CASE DEFAULT
         CPASSERT(.FALSE.)
      END SELECT

      ! check for consistency of options
      SELECT CASE (kpoint%kp_scheme)
      CASE ("NONE")
         ! don't use k-point code
      CASE ("GAMMA")
         CPASSERT(kpoint%nkp == 1)
         CPASSERT(SUM(ABS(kpoint%xkp)) <= 1.e-12_dp)
         CPASSERT(kpoint%wkp(1) == 1.0_dp)
         CPASSERT(.NOT. kpoint%symmetry)
      CASE ("GENERAL")
         CPASSERT(.NOT. kpoint%symmetry)
         CPASSERT(kpoint%nkp >= 1)
      CASE ("MONKHORST-PACK", "MACDONALD")
         CPASSERT(kpoint%nkp >= 1)
      END SELECT
      IF (kpoint%use_real_wfn) THEN
         ! what about inversion symmetry?
         ikloop: DO ik = 1, kpoint%nkp
            DO i = 1, 3
               spez = (kpoint%xkp(i, ik) == 0.0_dp .OR. kpoint%xkp(i, ik) == 0.5_dp)
               IF (.NOT. spez) EXIT ikloop
            END DO
         END DO ikloop
         IF (.NOT. spez) THEN
            ! Warning: real wfn might be wrong for this system
            CALL cp_warn(__LOCATION__, &
                         "A calculation using real wavefunctions is requested. "// &
                         "We could not determine if the symmetry of the system allows real wavefunctions. ")
         END IF
      END IF

      CALL timestop(handle)

   END SUBROUTINE kpoint_initialize

! **************************************************************************************************
!> \brief Initialize the kpoint environment
!> \param kpoint       Kpoint environment
!> \param para_env ...
!> \param blacs_env ...
!> \param with_aux_fit ...
! **************************************************************************************************
   SUBROUTINE kpoint_env_initialize(kpoint, para_env, blacs_env, with_aux_fit)

      TYPE(kpoint_type), INTENT(INOUT)                   :: kpoint
      TYPE(mp_para_env_type), INTENT(IN), TARGET         :: para_env
      TYPE(cp_blacs_env_type), INTENT(IN), TARGET        :: blacs_env
      LOGICAL, INTENT(IN), OPTIONAL                      :: with_aux_fit

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_env_initialize'

      INTEGER                                            :: handle, igr, ik, ikk, ngr, niogrp, nkp, &
                                                            nkp_grp, nkp_loc, npe, unit_nr
      INTEGER, DIMENSION(2)                              :: dims, pos
      LOGICAL                                            :: aux_fit
      TYPE(kpoint_env_p_type), DIMENSION(:), POINTER     :: kp_aux_env, kp_env
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(mp_cart_type)                                 :: comm_cart
      TYPE(mp_para_env_type), POINTER                    :: para_env_inter_kp, para_env_kp

      CALL timeset(routineN, handle)

      IF (PRESENT(with_aux_fit)) THEN
         aux_fit = with_aux_fit
      ELSE
         aux_fit = .FALSE.
      END IF

      kpoint%para_env => para_env
      CALL kpoint%para_env%retain()
      kpoint%blacs_env_all => blacs_env
      CALL kpoint%blacs_env_all%retain()

      CPASSERT(.NOT. ASSOCIATED(kpoint%kp_env))
      IF (aux_fit) THEN
         CPASSERT(.NOT. ASSOCIATED(kpoint%kp_aux_env))
      END IF

      NULLIFY (kp_env, kp_aux_env)
      nkp = kpoint%nkp
      npe = para_env%num_pe
      IF (npe == 1) THEN
         ! only one process available -> owns all kpoints
         ALLOCATE (kp_env(nkp))
         DO ik = 1, nkp
            NULLIFY (kp_env(ik)%kpoint_env)
            CALL kpoint_env_create(kp_env(ik)%kpoint_env)
            kp => kp_env(ik)%kpoint_env
            kp%nkpoint = ik
            kp%wkp = kpoint%wkp(ik)
            kp%xkp(1:3) = kpoint%xkp(1:3, ik)
            kp%is_local = .TRUE.
         END DO
         kpoint%kp_env => kp_env

         IF (aux_fit) THEN
            ALLOCATE (kp_aux_env(nkp))
            DO ik = 1, nkp
               NULLIFY (kp_aux_env(ik)%kpoint_env)
               CALL kpoint_env_create(kp_aux_env(ik)%kpoint_env)
               kp => kp_aux_env(ik)%kpoint_env
               kp%nkpoint = ik
               kp%wkp = kpoint%wkp(ik)
               kp%xkp(1:3) = kpoint%xkp(1:3, ik)
               kp%is_local = .TRUE.
            END DO

            kpoint%kp_aux_env => kp_aux_env
         END IF

         ALLOCATE (kpoint%kp_dist(2, 1))
         kpoint%kp_dist(1, 1) = 1
         kpoint%kp_dist(2, 1) = nkp
         kpoint%kp_range(1) = 1
         kpoint%kp_range(2) = nkp

         ! parallel environments
         kpoint%para_env_kp => para_env
         CALL kpoint%para_env_kp%retain()
         kpoint%para_env_inter_kp => para_env
         CALL kpoint%para_env_inter_kp%retain()
         kpoint%iogrp = .TRUE.
         kpoint%nkp_groups = 1
      ELSE
         IF (kpoint%parallel_group_size == -1) THEN
            ! maximum parallelization over kpoints
            ! making sure that the group size divides the npe and the nkp_grp the nkp
            ! in the worst case, there will be no parallelism over kpoints.
            DO igr = npe, 1, -1
               IF (MOD(npe, igr) .NE. 0) CYCLE
               nkp_grp = npe/igr
               IF (MOD(nkp, nkp_grp) .NE. 0) CYCLE
               ngr = igr
            END DO
         ELSE IF (kpoint%parallel_group_size == 0) THEN
            ! no parallelization over kpoints
            ngr = npe
         ELSE IF (kpoint%parallel_group_size > 0) THEN
            ngr = MIN(kpoint%parallel_group_size, npe)
         ELSE
            CPASSERT(.FALSE.)
         END IF
         nkp_grp = npe/ngr
         ! processor dimensions
         dims(1) = ngr
         dims(2) = nkp_grp
         CPASSERT(MOD(nkp, nkp_grp) == 0)
         nkp_loc = nkp/nkp_grp

         IF ((dims(1)*dims(2) /= npe)) THEN
            CPABORT("Number of processors is not divisible by the kpoint group size.")
         END IF

         ! Create the subgroups, one for each k-point group and one interconnecting group
         CALL comm_cart%create(comm_old=para_env, ndims=2, dims=dims)
         pos = comm_cart%mepos_cart
         ALLOCATE (para_env_kp)
         CALL para_env_kp%from_split(comm_cart, pos(2))
         ALLOCATE (para_env_inter_kp)
         CALL para_env_inter_kp%from_split(comm_cart, pos(1))
         CALL comm_cart%free()

         niogrp = 0
         IF (para_env%is_source()) niogrp = 1
         CALL para_env_kp%sum(niogrp)
         kpoint%iogrp = (niogrp == 1)

         ! parallel groups
         kpoint%para_env_kp => para_env_kp
         kpoint%para_env_inter_kp => para_env_inter_kp

         ! distribution of kpoints
         ALLOCATE (kpoint%kp_dist(2, nkp_grp))
         DO igr = 1, nkp_grp
            kpoint%kp_dist(1:2, igr) = get_limit(nkp, nkp_grp, igr - 1)
         END DO
         ! local kpoints
         kpoint%kp_range(1:2) = kpoint%kp_dist(1:2, para_env_inter_kp%mepos + 1)

         ALLOCATE (kp_env(nkp_loc))
         DO ik = 1, nkp_loc
            NULLIFY (kp_env(ik)%kpoint_env)
            ikk = kpoint%kp_range(1) + ik - 1
            CALL kpoint_env_create(kp_env(ik)%kpoint_env)
            kp => kp_env(ik)%kpoint_env
            kp%nkpoint = ikk
            kp%wkp = kpoint%wkp(ikk)
            kp%xkp(1:3) = kpoint%xkp(1:3, ikk)
            kp%is_local = (ngr == 1)
         END DO
         kpoint%kp_env => kp_env

         IF (aux_fit) THEN
            ALLOCATE (kp_aux_env(nkp_loc))
            DO ik = 1, nkp_loc
               NULLIFY (kp_aux_env(ik)%kpoint_env)
               ikk = kpoint%kp_range(1) + ik - 1
               CALL kpoint_env_create(kp_aux_env(ik)%kpoint_env)
               kp => kp_aux_env(ik)%kpoint_env
               kp%nkpoint = ikk
               kp%wkp = kpoint%wkp(ikk)
               kp%xkp(1:3) = kpoint%xkp(1:3, ikk)
               kp%is_local = (ngr == 1)
            END DO
            kpoint%kp_aux_env => kp_aux_env
         END IF

         unit_nr = cp_logger_get_default_io_unit()

         IF (unit_nr > 0 .AND. kpoint%verbose) THEN
            WRITE (unit_nr, *)
            WRITE (unit_nr, FMT="(T2,A,T71,I10)") "KPOINTS| Number of kpoint groups ", nkp_grp
            WRITE (unit_nr, FMT="(T2,A,T71,I10)") "KPOINTS| Size of each kpoint group", ngr
            WRITE (unit_nr, FMT="(T2,A,T71,I10)") "KPOINTS| Number of kpoints per group", nkp_loc
         END IF
         kpoint%nkp_groups = nkp_grp

      END IF

      CALL timestop(handle)

   END SUBROUTINE kpoint_env_initialize

! **************************************************************************************************
!> \brief Initialize a set of MOs and density matrix for each kpoint (kpoint group)
!> \param kpoint  Kpoint environment
!> \param mos     Reference MOs (global)
!> \param added_mos ...
!> \param for_aux_fit ...
! **************************************************************************************************
   SUBROUTINE kpoint_initialize_mos(kpoint, mos, added_mos, for_aux_fit)

      TYPE(kpoint_type), POINTER                         :: kpoint
      TYPE(mo_set_type), DIMENSION(:), INTENT(INOUT)     :: mos
      INTEGER, INTENT(IN), OPTIONAL                      :: added_mos
      LOGICAL, OPTIONAL                                  :: for_aux_fit

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_initialize_mos'

      INTEGER                                            :: handle, ic, ik, is, nadd, nao, nc, &
                                                            nelectron, nkp_loc, nmo, nmorig(2), &
                                                            nspin
      LOGICAL                                            :: aux_fit
      REAL(KIND=dp)                                      :: flexible_electron_count, maxocc, n_el_f
      TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
      TYPE(cp_fm_pool_p_type), DIMENSION(:), POINTER     :: ao_ao_fm_pools
      TYPE(cp_fm_struct_type), POINTER                   :: matrix_struct
      TYPE(cp_fm_type), POINTER                          :: fmlocal
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(qs_matrix_pools_type), POINTER                :: mpools

      CALL timeset(routineN, handle)

      IF (PRESENT(for_aux_fit)) THEN
         aux_fit = for_aux_fit
      ELSE
         aux_fit = .FALSE.
      END IF

      CPASSERT(ASSOCIATED(kpoint))

      IF (.TRUE. .OR. ASSOCIATED(mos(1)%mo_coeff)) THEN
         IF (aux_fit) THEN
            CPASSERT(ASSOCIATED(kpoint%kp_aux_env))
         END IF

         IF (PRESENT(added_mos)) THEN
            nadd = added_mos
         ELSE
            nadd = 0
         END IF

         IF (kpoint%use_real_wfn) THEN
            nc = 1
         ELSE
            nc = 2
         END IF
         nspin = SIZE(mos, 1)
         nkp_loc = kpoint%kp_range(2) - kpoint%kp_range(1) + 1
         IF (nkp_loc > 0) THEN
            IF (aux_fit) THEN
               CPASSERT(SIZE(kpoint%kp_aux_env) == nkp_loc)
            ELSE
               CPASSERT(SIZE(kpoint%kp_env) == nkp_loc)
            END IF
            ! allocate the mo sets, correct number of kpoints (local), real/complex, spin
            DO ik = 1, nkp_loc
               IF (aux_fit) THEN
                  kp => kpoint%kp_aux_env(ik)%kpoint_env
               ELSE
                  kp => kpoint%kp_env(ik)%kpoint_env
               END IF
               ALLOCATE (kp%mos(nc, nspin))
               DO is = 1, nspin
                  CALL get_mo_set(mos(is), nao=nao, nmo=nmo, nelectron=nelectron, &
                                  n_el_f=n_el_f, maxocc=maxocc, flexible_electron_count=flexible_electron_count)
                  nmo = MIN(nao, nmo + nadd)
                  DO ic = 1, nc
                     CALL allocate_mo_set(kp%mos(ic, is), nao, nmo, nelectron, n_el_f, maxocc, &
                                          flexible_electron_count)
                  END DO
               END DO
            END DO

            ! generate the blacs environment for the kpoint group
            ! we generate a blacs env for each kpoint group in parallel
            ! we assume here that the group para_env_inter_kp will connect
            ! equivalent parts of fm matrices, i.e. no reshuffeling of processors
            NULLIFY (blacs_env)
            IF (ASSOCIATED(kpoint%blacs_env)) THEN
               blacs_env => kpoint%blacs_env
            ELSE
               CALL cp_blacs_env_create(blacs_env=blacs_env, para_env=kpoint%para_env_kp)
               kpoint%blacs_env => blacs_env
            END IF

            ! set possible new number of MOs
            DO is = 1, nspin
               CALL get_mo_set(mos(is), nmo=nmorig(is))
               nmo = MIN(nao, nmorig(is) + nadd)
               CALL set_mo_set(mos(is), nmo=nmo)
            END DO
            ! matrix pools for the kpoint group, information on MOs is transferred using
            ! generic mos structure
            NULLIFY (mpools)
            CALL mpools_create(mpools=mpools)
            CALL mpools_rebuild_fm_pools(mpools=mpools, mos=mos, &
                                         blacs_env=blacs_env, para_env=kpoint%para_env_kp)

            IF (aux_fit) THEN
               kpoint%mpools_aux_fit => mpools
            ELSE
               kpoint%mpools => mpools
            END IF

            ! reset old number of MOs
            DO is = 1, nspin
               CALL set_mo_set(mos(is), nmo=nmorig(is))
            END DO

            ! allocate density matrices
            CALL mpools_get(mpools, ao_ao_fm_pools=ao_ao_fm_pools)
            ALLOCATE (fmlocal)
            CALL fm_pool_create_fm(ao_ao_fm_pools(1)%pool, fmlocal)
            CALL cp_fm_get_info(fmlocal, matrix_struct=matrix_struct)
            DO ik = 1, nkp_loc
               IF (aux_fit) THEN
                  kp => kpoint%kp_aux_env(ik)%kpoint_env
               ELSE
                  kp => kpoint%kp_env(ik)%kpoint_env
               END IF
               ! density matrix
               CALL cp_fm_release(kp%pmat)
               ALLOCATE (kp%pmat(nc, nspin))
               DO is = 1, nspin
                  DO ic = 1, nc
                     CALL cp_fm_create(kp%pmat(ic, is), matrix_struct)
                  END DO
               END DO
               ! energy weighted density matrix
               CALL cp_fm_release(kp%wmat)
               ALLOCATE (kp%wmat(nc, nspin))
               DO is = 1, nspin
                  DO ic = 1, nc
                     CALL cp_fm_create(kp%wmat(ic, is), matrix_struct)
                  END DO
               END DO
            END DO
            CALL fm_pool_give_back_fm(ao_ao_fm_pools(1)%pool, fmlocal)
            DEALLOCATE (fmlocal)

         END IF

      END IF

      CALL timestop(handle)

   END SUBROUTINE kpoint_initialize_mos

! **************************************************************************************************
!> \brief ...
!> \param kpoint ...
! **************************************************************************************************
   SUBROUTINE kpoint_initialize_mo_set(kpoint)
      TYPE(kpoint_type), POINTER                         :: kpoint

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_initialize_mo_set'

      INTEGER                                            :: handle, ic, ik, ikk, ispin
      TYPE(cp_fm_pool_p_type), DIMENSION(:), POINTER     :: ao_mo_fm_pools
      TYPE(cp_fm_type), POINTER                          :: mo_coeff
      TYPE(mo_set_type), DIMENSION(:, :), POINTER        :: moskp

      CALL timeset(routineN, handle)

      DO ik = 1, SIZE(kpoint%kp_env)
         CALL mpools_get(kpoint%mpools, ao_mo_fm_pools=ao_mo_fm_pools)
         moskp => kpoint%kp_env(ik)%kpoint_env%mos
         ikk = kpoint%kp_range(1) + ik - 1
         CPASSERT(ASSOCIATED(moskp))
         DO ispin = 1, SIZE(moskp, 2)
            DO ic = 1, SIZE(moskp, 1)
               CALL get_mo_set(moskp(ic, ispin), mo_coeff=mo_coeff)
               IF (.NOT. ASSOCIATED(mo_coeff)) THEN
                  CALL init_mo_set(moskp(ic, ispin), &
                                   fm_pool=ao_mo_fm_pools(ispin)%pool, name="kpoints")
               END IF
            END DO
         END DO
      END DO

      CALL timestop(handle)

   END SUBROUTINE kpoint_initialize_mo_set

! **************************************************************************************************
!> \brief Generates the mapping of cell indices and linear RS index
!>        CELL (0,0,0) is always mapped to index 1
!> \param kpoint    Kpoint environment
!> \param sab_nl    Defining neighbour list
!> \param para_env  Parallel environment
!> \param dft_control ...
! **************************************************************************************************
   SUBROUTINE kpoint_init_cell_index(kpoint, sab_nl, para_env, dft_control)

      TYPE(kpoint_type), POINTER                         :: kpoint
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(dft_control_type), POINTER                    :: dft_control

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_init_cell_index'

      INTEGER                                            :: handle, i1, i2, i3, ic, icount, it, &
                                                            ncount
      INTEGER, DIMENSION(3)                              :: cell, itm
      INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell, list
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index, cti
      LOGICAL                                            :: new
      TYPE(neighbor_list_iterator_p_type), &
         DIMENSION(:), POINTER                           :: nl_iterator

      NULLIFY (cell_to_index, index_to_cell)

      CALL timeset(routineN, handle)

      CPASSERT(ASSOCIATED(kpoint))

      ALLOCATE (list(3, 125))
      list = 0
      icount = 1

      CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
      DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
         CALL get_iterator_info(nl_iterator, cell=cell)

         new = .TRUE.
         DO ic = 1, icount
            IF (cell(1) == list(1, ic) .AND. cell(2) == list(2, ic) .AND. &
                cell(3) == list(3, ic)) THEN
               new = .FALSE.
               EXIT
            END IF
         END DO
         IF (new) THEN
            icount = icount + 1
            IF (icount > SIZE(list, 2)) THEN
               CALL reallocate(list, 1, 3, 1, 2*SIZE(list, 2))
            END IF
            list(1:3, icount) = cell(1:3)
         END IF

      END DO
      CALL neighbor_list_iterator_release(nl_iterator)

      itm(1) = MAXVAL(ABS(list(1, 1:icount)))
      itm(2) = MAXVAL(ABS(list(2, 1:icount)))
      itm(3) = MAXVAL(ABS(list(3, 1:icount)))
      CALL para_env%max(itm)
      it = MAXVAL(itm(1:3))
      IF (ASSOCIATED(kpoint%cell_to_index)) THEN
         DEALLOCATE (kpoint%cell_to_index)
      END IF
      ALLOCATE (kpoint%cell_to_index(-itm(1):itm(1), -itm(2):itm(2), -itm(3):itm(3)))
      cell_to_index => kpoint%cell_to_index
      cti => cell_to_index
      cti(:, :, :) = 0
      DO ic = 1, icount
         i1 = list(1, ic)
         i2 = list(2, ic)
         i3 = list(3, ic)
         cti(i1, i2, i3) = ic
      END DO
      CALL para_env%sum(cti)
      ncount = 0
      DO i1 = -itm(1), itm(1)
         DO i2 = -itm(2), itm(2)
            DO i3 = -itm(3), itm(3)
               IF (cti(i1, i2, i3) == 0) THEN
                  cti(i1, i2, i3) = 1000000
               ELSE
                  ncount = ncount + 1
                  cti(i1, i2, i3) = (ABS(i1) + ABS(i2) + ABS(i3))*1000 + ABS(i3)*100 + ABS(i2)*10 + ABS(i1)
                  cti(i1, i2, i3) = cti(i1, i2, i3) + (i1 + i2 + i3)
               END IF
            END DO
         END DO
      END DO

      IF (ASSOCIATED(kpoint%index_to_cell)) THEN
         DEALLOCATE (kpoint%index_to_cell)
      END IF
      ALLOCATE (kpoint%index_to_cell(3, ncount))
      index_to_cell => kpoint%index_to_cell
      DO ic = 1, ncount
         cell = MINLOC(cti)
         i1 = cell(1) - 1 - itm(1)
         i2 = cell(2) - 1 - itm(2)
         i3 = cell(3) - 1 - itm(3)
         cti(i1, i2, i3) = 1000000
         index_to_cell(1, ic) = i1
         index_to_cell(2, ic) = i2
         index_to_cell(3, ic) = i3
      END DO
      cti(:, :, :) = 0
      DO ic = 1, ncount
         i1 = index_to_cell(1, ic)
         i2 = index_to_cell(2, ic)
         i3 = index_to_cell(3, ic)
         cti(i1, i2, i3) = ic
      END DO

      ! keep pointer to this neighborlist
      kpoint%sab_nl => sab_nl

      ! set number of images
      dft_control%nimages = SIZE(index_to_cell, 2)

      DEALLOCATE (list)

      CALL timestop(handle)
   END SUBROUTINE kpoint_init_cell_index

! **************************************************************************************************
!> \brief Transformation of real space matrices to a kpoint
!> \param rmatrix  Real part of kpoint matrix
!> \param cmatrix  Complex part of kpoint matrix (optional)
!> \param rsmat    Real space matrices
!> \param ispin    Spin index
!> \param xkp      Kpoint coordinates
!> \param cell_to_index   mapping of cell indices to RS index
!> \param sab_nl   Defining neighbor list
!> \param is_complex  Matrix to be transformed is imaginary
!> \param rs_sign  Matrix to be transformed is scaled by rs_sign
!> \par Notes
!>      The result is accumulated (not overwritten): for every neighbor list pair with cell R the
!>      real space block of image R is added to rmatrix with weight cos(2*pi*k.R) and, only if
!>      cmatrix is present, to cmatrix with weight sin(2*pi*k.R) (signs see below).
!>      Pairs with an invalid image index or missing blocks are skipped.
! **************************************************************************************************
   SUBROUTINE rskp_transform(rmatrix, cmatrix, rsmat, ispin, &
                             xkp, cell_to_index, sab_nl, is_complex, rs_sign)

      TYPE(dbcsr_type)                                   :: rmatrix
      TYPE(dbcsr_type), OPTIONAL                         :: cmatrix
      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rsmat
      INTEGER, INTENT(IN)                                :: ispin
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: xkp
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl
      LOGICAL, INTENT(IN), OPTIONAL                      :: is_complex
      REAL(KIND=dp), INTENT(IN), OPTIONAL                :: rs_sign

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rskp_transform'

      INTEGER                                            :: blk_col, blk_row, handle, iatom, img, &
                                                            jatom, n_images
      INTEGER, DIMENSION(3)                              :: rcell
      LOGICAL                                            :: found, imag_input, need_imag, sym_nl
      REAL(KIND=dp)                                      :: phase, scale_rs, swap_sign, wcos, wsin
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: cblock, rblock, rsblock
      TYPE(neighbor_list_iterator_p_type), &
         DIMENSION(:), POINTER                           :: nl_iterator

      CALL timeset(routineN, handle)

      imag_input = .FALSE.
      IF (PRESENT(is_complex)) imag_input = is_complex
      scale_rs = 1.0_dp
      IF (PRESENT(rs_sign)) scale_rs = rs_sign
      ! the imaginary part is only accumulated if a target matrix for it was passed
      need_imag = PRESENT(cmatrix)

      n_images = SIZE(rsmat, 2)
      CALL get_neighbor_list_set_p(neighbor_list_sets=sab_nl, symmetric=sym_nl)

      CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
      DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
         CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=rcell)

         ! In a symmetric list only the block with row <= col is stored. Real space matrices are
         ! not symmetric themselves (S_mu^0,nu^b = S_nu^0,mu^-b) while the k-point matrices are
         ! Hermitian, so a folded (lower triangle) pair carries the extra factor swap_sign = -1.
         IF (sym_nl .AND. iatom > jatom) THEN
            blk_row = jatom
            blk_col = iatom
            swap_sign = -1.0_dp
         ELSE
            blk_row = iatom
            blk_col = jatom
            swap_sign = 1.0_dp
         END IF

         img = cell_to_index(rcell(1), rcell(2), rcell(3))
         IF (img < 1 .OR. img > n_images) CYCLE

         phase = REAL(rcell(1), dp)*xkp(1) + REAL(rcell(2), dp)*xkp(2) + REAL(rcell(3), dp)*xkp(3)
         ! for an imaginary input matrix the fold sign moves from the sine to the cosine weight
         IF (imag_input) THEN
            wcos = scale_rs*swap_sign*COS(twopi*phase)
            wsin = scale_rs*SIN(twopi*phase)
         ELSE
            wcos = scale_rs*COS(twopi*phase)
            wsin = scale_rs*swap_sign*SIN(twopi*phase)
         END IF

         CALL dbcsr_get_block_p(matrix=rsmat(ispin, img)%matrix, row=blk_row, col=blk_col, &
                                block=rsblock, found=found)
         IF (.NOT. found) CYCLE
         CALL dbcsr_get_block_p(matrix=rmatrix, row=blk_row, col=blk_col, &
                                block=rblock, found=found)
         IF (.NOT. found) CYCLE
         IF (need_imag) THEN
            CALL dbcsr_get_block_p(matrix=cmatrix, row=blk_row, col=blk_col, &
                                   block=cblock, found=found)
            IF (.NOT. found) CYCLE
         END IF

         rblock = rblock + wcos*rsblock
         IF (need_imag) cblock = cblock + wsin*rsblock
      END DO
      CALL neighbor_list_iterator_release(nl_iterator)

      CALL timestop(handle)

   END SUBROUTINE rskp_transform

! **************************************************************************************************
!> \brief Given the eigenvalues of all kpoints, calculates the occupation numbers
!> \param kpoint  Kpoint environment
!> \param smear   Smearing information
!> \par Notes
!>      The eigenvalues of the locally owned k-points are written into a global (nmo, nkp, nspin)
!>      table, which is summed over the inter-kpoint communicator. Occupations are then obtained
!>      with Fermikp/Fermikp2 and written back, together with kTS and mu, to the local mo_sets.
! **************************************************************************************************
   SUBROUTINE kpoint_set_mo_occupation(kpoint, smear)

      TYPE(kpoint_type), POINTER                         :: kpoint
      TYPE(smear_type), POINTER                          :: smear

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_set_mo_occupation'

      INTEGER                                            :: first_kp, handle, ikglob, ikloc, is, &
                                                            n_alpha, n_beta, nelectron, nkp, nlocal, &
                                                            nmo, nmo_b, nspin
      INTEGER, DIMENSION(2)                              :: kp_range
      REAL(KIND=dp)                                      :: kTS, mu, mus(2), nel
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: all_eig, all_occ
      REAL(KIND=dp), DIMENSION(:), POINTER               :: eigenvalues, occupation, wkp
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(mo_set_type), POINTER                         :: mo_set
      TYPE(mp_para_env_type), POINTER                    :: para_env_inter_kp

      CALL timeset(routineN, handle)

      CALL get_kpoint_info(kpoint, nkp=nkp, kp_range=kp_range, wkp=wkp, &
                           para_env_inter_kp=para_env_inter_kp)

      ! band count and electron numbers are taken from the first local k-point
      kp => kpoint%kp_env(1)%kpoint_env
      nspin = SIZE(kp%mos, 2)
      CALL get_mo_set(kp%mos(1, 1), nmo=nmo, nelectron=nelectron)
      n_alpha = nelectron
      IF (nspin == 2) THEN
         CALL get_mo_set(kp%mos(1, 2), nmo=nmo_b, nelectron=n_beta)
         CPASSERT(nmo == nmo_b)
      END IF

      ! global eigenvalue table: each k-point group fills its own slice, the rest stays zero,
      ! so the sum over groups yields the complete table everywhere
      ALLOCATE (all_eig(nmo, nkp, nspin), all_occ(nmo, nkp, nspin))
      all_eig(:, :, :) = 0.0_dp
      all_occ(:, :, :) = 0.0_dp
      first_kp = kp_range(1)
      nlocal = kp_range(2) - kp_range(1) + 1
      DO ikloc = 1, nlocal
         ikglob = first_kp - 1 + ikloc
         DO is = 1, nspin
            CALL get_mo_set(kpoint%kp_env(ikloc)%kpoint_env%mos(1, is), eigenvalues=eigenvalues)
            all_eig(1:nmo, ikglob, is) = eigenvalues(1:nmo)
         END DO
      END DO
      CALL para_env_inter_kp%sum(all_eig)

      IF (.NOT. smear%do_smear) THEN
         ! fixed occupations: 2 per orbital for closed shell, 1 per spin orbital otherwise
         IF (nspin == 1) THEN
            nel = REAL(nelectron, KIND=dp)
            CALL Fermikp(all_occ(:, :, 1), mus(1), kTS, all_eig(:, :, 1), nel, wkp, 0.0_dp, 2.0_dp)
         ELSE
            nel = REAL(n_alpha, KIND=dp)
            CALL Fermikp(all_occ(:, :, 1), mus(1), kTS, all_eig(:, :, 1), nel, wkp, 0.0_dp, 1.0_dp)
            nel = REAL(n_beta, KIND=dp)
            CALL Fermikp(all_occ(:, :, 2), mus(2), kTS, all_eig(:, :, 2), nel, wkp, 0.0_dp, 1.0_dp)
         END IF
      ELSE IF (smear%method == smear_fermi_dirac) THEN
         ! finite electronic temperature, Fermi-Dirac smearing
         IF (nspin == 1) THEN
            nel = REAL(nelectron, KIND=dp)
            CALL Fermikp(all_occ(:, :, 1), mus(1), kTS, all_eig(:, :, 1), nel, wkp, &
                         smear%electronic_temperature, 2.0_dp)
         ELSE IF (smear%fixed_mag_mom > 0.0_dp) THEN
            CPABORT("kpoints: Smearing with fixed magnetic moments not (yet) supported")
            ! NOTE: presumably unreachable (CPABORT aborts); kept as template for separate spins
            nel = REAL(n_alpha, KIND=dp)
            CALL Fermikp(all_occ(:, :, 1), mus(1), kTS, all_eig(:, :, 1), nel, wkp, &
                         smear%electronic_temperature, 1.0_dp)
            nel = REAL(n_beta, KIND=dp)
            CALL Fermikp(all_occ(:, :, 2), mus(2), kTS, all_eig(:, :, 2), nel, wkp, &
                         smear%electronic_temperature, 1.0_dp)
         ELSE
            ! one common chemical potential for both spins; kTS is halved because it is
            ! stored on each of the two spin mo_sets below
            nel = REAL(n_alpha, KIND=dp) + REAL(n_beta, KIND=dp)
            CALL Fermikp2(all_occ(:, :, :), mu, kTS, all_eig(:, :, :), nel, wkp, &
                          smear%electronic_temperature)
            kTS = kTS/2._dp
            mus(1:2) = mu
         END IF
      ELSE
         CPABORT("kpoints: Selected smearing not (yet) supported")
      END IF

      ! scatter the results back to the locally owned k-points
      DO ikloc = 1, nlocal
         ikglob = first_kp - 1 + ikloc
         kp => kpoint%kp_env(ikloc)%kpoint_env
         DO is = 1, nspin
            mo_set => kp%mos(1, is)
            CALL get_mo_set(mo_set, eigenvalues=eigenvalues, occupation_numbers=occupation)
            eigenvalues(1:nmo) = all_eig(1:nmo, ikglob, is)
            occupation(1:nmo) = all_occ(1:nmo, ikglob, is)
            mo_set%kTS = kTS
            mo_set%mu = mus(is)
         END DO
      END DO

      DEALLOCATE (all_eig, all_occ)

      CALL timestop(handle)

   END SUBROUTINE kpoint_set_mo_occupation

! **************************************************************************************************
!> \brief Calculate kpoint density matrices (rho(k), owned by kpoint groups)
!> \param kpoint    kpoint environment
!> \param energy_weighted  calculate energy weighted density matrix
!> \param for_aux_fit  use the auxiliary fit k-point environment (kpoint%kp_aux_env)
!> \par Notes
!>      With f = occupation numbers (times eigenvalues if energy weighted), Re(c) = mos(1,ispin)
!>      and Im(c) = mos(2,ispin) the matrices built are
!>        Re P = Re(c) f Re(c)^T + Im(c) f Im(c)^T
!>        Im P = Im(c) f Re(c)^T - Re(c) f Im(c)^T
!>      Only Re P (with Re(c) only) is built if kpoint%use_real_wfn is set.
!>      The result is stored in kp%wmat if energy weighted, otherwise in kp%pmat.
! **************************************************************************************************
   SUBROUTINE kpoint_density_matrices(kpoint, energy_weighted, for_aux_fit)

      TYPE(kpoint_type), POINTER                         :: kpoint
      LOGICAL, OPTIONAL                                  :: energy_weighted, for_aux_fit

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_density_matrices'

      INTEGER                                            :: handle, ikpgr, ispin, kplocal, nao, nmo, &
                                                            nspin
      INTEGER, DIMENSION(2)                              :: kp_range
      LOGICAL                                            :: aux_fit, wtype
      REAL(KIND=dp), DIMENSION(:), POINTER               :: eigenvalues, occupation
      TYPE(cp_fm_struct_type), POINTER                   :: matrix_struct
      TYPE(cp_fm_type)                                   :: fwork
      TYPE(cp_fm_type), POINTER                          :: cpmat, rpmat
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(mo_set_type), POINTER                         :: mo_im, mo_re

      CALL timeset(routineN, handle)

      ! default: plain (not energy weighted) density matrix of the primary basis
      wtype = .FALSE.
      IF (PRESENT(energy_weighted)) wtype = energy_weighted
      aux_fit = .FALSE.
      IF (PRESENT(for_aux_fit)) aux_fit = for_aux_fit

      ! the work matrix gets the layout of the MO coefficients of the first local k-point
      IF (aux_fit) THEN
         CPASSERT(ASSOCIATED(kpoint%kp_aux_env))
         mo_re => kpoint%kp_aux_env(1)%kpoint_env%mos(1, 1)
      ELSE
         mo_re => kpoint%kp_env(1)%kpoint_env%mos(1, 1)
      END IF
      CALL get_mo_set(mo_re, nao=nao, nmo=nmo)
      CALL cp_fm_get_info(mo_re%mo_coeff, matrix_struct=matrix_struct)
      CALL cp_fm_create(fwork, matrix_struct)

      CALL get_kpoint_info(kpoint, kp_range=kp_range)
      kplocal = kp_range(2) - kp_range(1) + 1
      DO ikpgr = 1, kplocal
         IF (aux_fit) THEN
            kp => kpoint%kp_aux_env(ikpgr)%kpoint_env
         ELSE
            kp => kpoint%kp_env(ikpgr)%kpoint_env
         END IF
         nspin = SIZE(kp%mos, 2)
         DO ispin = 1, nspin
            mo_re => kp%mos(1, ispin)
            CALL get_mo_set(mo_re, occupation_numbers=occupation)
            IF (wtype) THEN
               CALL get_mo_set(mo_re, eigenvalues=eigenvalues)
               rpmat => kp%wmat(1, ispin)
            ELSE
               rpmat => kp%pmat(1, ispin)
            END IF

            ! fwork = Re(c)*f
            CALL cp_fm_to_fm(mo_re%mo_coeff, fwork)
            CALL cp_fm_column_scale(fwork, occupation)
            IF (wtype) CALL cp_fm_column_scale(fwork, eigenvalues)
            ! Re(c)*Re(c)
            CALL parallel_gemm("N", "T", nao, nao, nmo, 1.0_dp, mo_re%mo_coeff, fwork, 0.0_dp, rpmat)

            ! real wavefunctions: no imaginary coefficients, done for this spin
            IF (kpoint%use_real_wfn) CYCLE

            IF (wtype) THEN
               cpmat => kp%wmat(2, ispin)
            ELSE
               cpmat => kp%pmat(2, ispin)
            END IF
            mo_im => kp%mos(2, ispin)
            ! Im(c)*Re(c)
            CALL parallel_gemm("N", "T", nao, nao, nmo, 1.0_dp, mo_im%mo_coeff, fwork, 0.0_dp, cpmat)
            ! Re(c)*Im(c)
            CALL parallel_gemm("N", "T", nao, nao, nmo, -1.0_dp, fwork, mo_im%mo_coeff, 1.0_dp, cpmat)
            ! fwork = Im(c)*f (same weights as for the real part)
            CALL cp_fm_to_fm(mo_im%mo_coeff, fwork)
            CALL cp_fm_column_scale(fwork, occupation)
            IF (wtype) CALL cp_fm_column_scale(fwork, eigenvalues)
            ! Im(c)*Im(c)
            CALL parallel_gemm("N", "T", nao, nao, nmo, 1.0_dp, mo_im%mo_coeff, fwork, 1.0_dp, rpmat)
         END DO
      END DO

      CALL cp_fm_release(fwork)

      CALL timestop(handle)

   END SUBROUTINE kpoint_density_matrices

! **************************************************************************************************
!> \brief generate real space density matrices in DBCSR format
!> \param kpoint  Kpoint environment
!> \param denmat  Real space (DBCSR) density matrices
!> \param wtype   True = energy weighted density matrix
!>                False = normal density matrix
!> \param tempmat DBCSR matrix to be used as template
!> \param sab_nl  Neighbor list defining the sparsity of the real space matrices
!> \param fmwork  FM work matrices (kpoint group)
!> \param for_aux_fit  use the auxiliary fit k-point environment (kpoint%kp_aux_env)
!> \param pmat_ext  if present, the k-point matrices are taken from here, indexed as
!>                  (local kpoint, real/imaginary part, spin), instead of kp%pmat/kp%wmat
!> \par Notes
!>      denmat(ispin, :) are zeroed and then accumulate, k-point by k-point, the back transform
!>      performed by transform_dmat. Every k-point matrix is first redistributed into fmwork on
!>      all processes. For k-points with apply_symmetry, the weight is shared among the
!>      symmetry-generated k-points; this requires a reduced (non-full) grid.
! **************************************************************************************************
   SUBROUTINE kpoint_density_transform(kpoint, denmat, wtype, tempmat, sab_nl, fmwork, for_aux_fit, pmat_ext)

      TYPE(kpoint_type), POINTER                         :: kpoint
      TYPE(dbcsr_p_type), DIMENSION(:, :)                :: denmat
      LOGICAL, INTENT(IN)                                :: wtype
      TYPE(dbcsr_type), POINTER                          :: tempmat
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl
      TYPE(cp_fm_type), DIMENSION(:), INTENT(IN)         :: fmwork
      LOGICAL, OPTIONAL                                  :: for_aux_fit
      TYPE(cp_fm_type), DIMENSION(:, :, :), INTENT(IN), &
         OPTIONAL                                        :: pmat_ext

      CHARACTER(LEN=*), PARAMETER :: routineN = 'kpoint_density_transform'

      INTEGER                                            :: handle, ic, ik, ikk, indx, is, ispin, &
                                                            nc, nimg, nkp, nspin
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      LOGICAL                                            :: aux_fit, do_ext, do_symmetric, my_kpgrp, &
                                                            real_only
      REAL(KIND=dp)                                      :: wkpx
      REAL(KIND=dp), DIMENSION(:), POINTER               :: wkp
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: xkp
      TYPE(copy_info_type), ALLOCATABLE, DIMENSION(:)    :: info
      TYPE(cp_fm_type)                                   :: fmdummy
      TYPE(dbcsr_type), POINTER                          :: cpmat, rpmat, scpmat, srpmat
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(kpoint_sym_type), POINTER                     :: kpsym
      TYPE(mp_para_env_type), POINTER                    :: para_env

      CALL timeset(routineN, handle)

      CALL get_neighbor_list_set_p(neighbor_list_sets=sab_nl, symmetric=do_symmetric)

      aux_fit = .FALSE.
      IF (PRESENT(for_aux_fit)) aux_fit = for_aux_fit
      ! external k-point matrices replace kp%pmat / kp%wmat as source of the transform
      do_ext = PRESENT(pmat_ext)

      IF (aux_fit) THEN
         CPASSERT(ASSOCIATED(kpoint%kp_aux_env))
      END IF

      ! work storage: for a symmetric neighbor list the real part is a symmetric and the
      ! imaginary part an antisymmetric DBCSR matrix, otherwise both have no symmetry
      ALLOCATE (rpmat)
      CALL dbcsr_create(rpmat, template=tempmat, &
                        matrix_type=MERGE(dbcsr_type_symmetric, dbcsr_type_no_symmetry, do_symmetric))
      CALL cp_dbcsr_alloc_block_from_nbl(rpmat, sab_nl)
      ALLOCATE (cpmat)
      CALL dbcsr_create(cpmat, template=tempmat, &
                        matrix_type=MERGE(dbcsr_type_antisymmetric, dbcsr_type_no_symmetry, do_symmetric))
      CALL cp_dbcsr_alloc_block_from_nbl(cpmat, sab_nl)
      ! symmetrization work matrices exist only for reduced grids; nullify so that their
      ! association status is well defined (checked before use below)
      NULLIFY (srpmat, scpmat)
      IF (.NOT. kpoint%full_grid) THEN
         ALLOCATE (srpmat)
         CALL dbcsr_create(srpmat, template=rpmat)
         CALL cp_dbcsr_alloc_block_from_nbl(srpmat, sab_nl)
         ALLOCATE (scpmat)
         CALL dbcsr_create(scpmat, template=cpmat)
         CALL cp_dbcsr_alloc_block_from_nbl(scpmat, sab_nl)
      END IF

      CALL get_kpoint_info(kpoint, nkp=nkp, xkp=xkp, wkp=wkp, &
                           cell_to_index=cell_to_index)
      ! dimensions are taken from the first local k-point
      IF (aux_fit) THEN
         kp => kpoint%kp_aux_env(1)%kpoint_env
      ELSE
         kp => kpoint%kp_env(1)%kpoint_env
      END IF
      nspin = SIZE(kp%mos, 2)
      nc = SIZE(kp%mos, 1)
      nimg = SIZE(denmat, 2)
      real_only = (nc == 1)
      ! one FM work matrix per component (real, and imaginary for complex k-points)
      CPASSERT(SIZE(fmwork) >= nc)

      para_env => kpoint%blacs_env_all%para_env
      ALLOCATE (info(nspin*nkp*nc))

      ! Start all the communication; processes outside the owning k-point group
      ! pass a dummy source matrix
      indx = 0
      DO ispin = 1, nspin
         DO ic = 1, nimg
            CALL dbcsr_set(denmat(ispin, ic)%matrix, 0.0_dp)
         END DO
         !
         DO ik = 1, nkp
            my_kpgrp = (ik >= kpoint%kp_range(1) .AND. ik <= kpoint%kp_range(2))
            IF (my_kpgrp) THEN
               ikk = ik - kpoint%kp_range(1) + 1
               IF (aux_fit) THEN
                  kp => kpoint%kp_aux_env(ikk)%kpoint_env
               ELSE
                  kp => kpoint%kp_env(ikk)%kpoint_env
               END IF
               DO ic = 1, nc
                  indx = indx + 1
                  IF (do_ext) THEN
                     CALL cp_fm_start_copy_general(pmat_ext(ikk, ic, ispin), fmwork(ic), para_env, info(indx))
                  ELSE IF (wtype) THEN
                     CALL cp_fm_start_copy_general(kp%wmat(ic, ispin), fmwork(ic), para_env, info(indx))
                  ELSE
                     CALL cp_fm_start_copy_general(kp%pmat(ic, ispin), fmwork(ic), para_env, info(indx))
                  END IF
               END DO
            ELSE
               DO ic = 1, nc
                  indx = indx + 1
                  CALL cp_fm_start_copy_general(fmdummy, fmwork(ic), para_env, info(indx))
               END DO
            END IF
         END DO
      END DO

      ! Finish communication and transform the received matrices
      indx = 0
      DO ispin = 1, nspin
         DO ik = 1, nkp
            DO ic = 1, nc
               indx = indx + 1
               CALL cp_fm_finish_copy_general(fmwork(ic), info(indx))
            END DO

            ! reduce to dbcsr storage
            CALL copy_fm_to_dbcsr(fmwork(1), rpmat, keep_sparsity=.TRUE.)
            IF (.NOT. real_only) THEN
               CALL copy_fm_to_dbcsr(fmwork(2), cpmat, keep_sparsity=.TRUE.)
            END IF

            ! symmetrization
            kpsym => kpoint%kp_sym(ik)%kpoint_sym
            CPASSERT(ASSOCIATED(kpsym))

            IF (kpsym%apply_symmetry) THEN
               ! the symmetrization work matrices are only allocated for reduced grids
               CPASSERT(ASSOCIATED(srpmat) .AND. ASSOCIATED(scpmat))
               wkpx = wkp(ik)/REAL(kpsym%nwght, KIND=dp)
               DO is = 1, kpsym%nwght
                  CALL symtrans(srpmat, rpmat, kpsym%rot(1:3, 1:3, is), kpsym%f0(:, is), symmetric=.TRUE.)
                  IF (.NOT. real_only) THEN
                     CALL symtrans(scpmat, cpmat, kpsym%rot(1:3, 1:3, is), kpsym%f0(:, is), antisymmetric=.TRUE.)
                  END IF
                  CALL transform_dmat(denmat, srpmat, scpmat, ispin, real_only, sab_nl, &
                                      cell_to_index, kpsym%xkp(1:3, is), wkpx)
               END DO
            ELSE
               ! transformation
               CALL transform_dmat(denmat, rpmat, cpmat, ispin, real_only, sab_nl, &
                                   cell_to_index, xkp(1:3, ik), wkp(ik))
            END IF
         END DO
      END DO

      ! Clean up communication; transfers started with the dummy source matrix
      ! (k-point not owned by this group) are not cleaned up, only counted
      indx = 0
      DO ispin = 1, nspin
         DO ik = 1, nkp
            my_kpgrp = (ik >= kpoint%kp_range(1) .AND. ik <= kpoint%kp_range(2))
            DO ic = 1, nc
               indx = indx + 1
               IF (my_kpgrp) CALL cp_fm_cleanup_copy_general(info(indx))
            END DO
         END DO
      END DO

      ! All done
      DEALLOCATE (info)

      CALL dbcsr_deallocate_matrix(rpmat)
      CALL dbcsr_deallocate_matrix(cpmat)
      IF (.NOT. kpoint%full_grid) THEN
         CALL dbcsr_deallocate_matrix(srpmat)
         CALL dbcsr_deallocate_matrix(scpmat)
      END IF

      CALL timestop(handle)

   END SUBROUTINE kpoint_density_transform

! **************************************************************************************************
!> \brief real space density matrices in DBCSR format
!> \param denmat  Real space (DBCSR) density matrix
!> \param rpmat   Real part of the k-point matrix
!> \param cpmat   Imaginary part of the k-point matrix (not referenced if real_only)
!> \param ispin   Spin index into denmat
!> \param real_only  only the real part rpmat is transformed
!> \param sab_nl  Neighbor list defining the atom pairs and cells
!> \param cell_to_index  mapping of cell indices to real space image index
!> \param xkp     Kpoint coordinates
!> \param wkp     Kpoint weight
!> \par Notes
!>      Accumulates into denmat(ispin, :):
!>        P(R) += wkp*[Re P(k) cos(2*pi*k.R) + Im P(k) sin(2*pi*k.R)]
!>      Pairs with an invalid image index or missing blocks are skipped.
! **************************************************************************************************
   SUBROUTINE transform_dmat(denmat, rpmat, cpmat, ispin, real_only, sab_nl, cell_to_index, xkp, wkp)

      TYPE(dbcsr_p_type), DIMENSION(:, :)                :: denmat
      TYPE(dbcsr_type), POINTER                          :: rpmat, cpmat
      INTEGER, INTENT(IN)                                :: ispin
      LOGICAL, INTENT(IN)                                :: real_only
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: xkp
      REAL(KIND=dp), INTENT(IN)                          :: wkp

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'transform_dmat'

      INTEGER                                            :: bcol, brow, handle, iat, img, jat, &
                                                            n_images
      INTEGER, DIMENSION(3)                              :: rcell
      LOGICAL                                            :: found, symmetric_nl
      REAL(KIND=dp)                                      :: kdotr, tsign, wcos, wsin
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: cblock, dblock, rblock
      TYPE(neighbor_list_iterator_p_type), &
         DIMENSION(:), POINTER                           :: nl_iterator

      CALL timeset(routineN, handle)

      n_images = SIZE(denmat, 2)
      CALL get_neighbor_list_set_p(neighbor_list_sets=sab_nl, symmetric=symmetric_nl)

      CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
      DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
         CALL get_iterator_info(nl_iterator, iatom=iat, jatom=jat, cell=rcell)

         ! Back transform S(R) = sum_k S(k)*exp(-i*k*R); taking the real part gives
         ! S(R) = sum_k Re(S(k))*cos(k*R) + Im(S(k))*sin(k*R).
         ! Real space matrices are non-symmetric but stored in symmetric form, so a pair
         ! folded onto the stored upper block flips the sign of the sine term (tsign = -1).
         IF (symmetric_nl .AND. iat > jat) THEN
            brow = jat
            bcol = iat
            tsign = -1.0_dp
         ELSE
            brow = iat
            bcol = jat
            tsign = 1.0_dp
         END IF

         img = cell_to_index(rcell(1), rcell(2), rcell(3))
         IF (img < 1 .OR. img > n_images) CYCLE

         kdotr = REAL(rcell(1), dp)*xkp(1) + REAL(rcell(2), dp)*xkp(2) + REAL(rcell(3), dp)*xkp(3)
         wcos = wkp*COS(twopi*kdotr)
         wsin = wkp*tsign*SIN(twopi*kdotr)

         CALL dbcsr_get_block_p(matrix=denmat(ispin, img)%matrix, row=brow, col=bcol, &
                                block=dblock, found=found)
         IF (.NOT. found) CYCLE
         CALL dbcsr_get_block_p(matrix=rpmat, row=brow, col=bcol, block=rblock, found=found)
         IF (.NOT. found) CYCLE
         IF (.NOT. real_only) THEN
            CALL dbcsr_get_block_p(matrix=cpmat, row=brow, col=bcol, block=cblock, found=found)
            IF (.NOT. found) CYCLE
         END IF

         dblock = dblock + wcos*rblock
         IF (.NOT. real_only) dblock = dblock + wsin*cblock
      END DO
      CALL neighbor_list_iterator_release(nl_iterator)

      CALL timestop(handle)

   END SUBROUTINE transform_dmat

! **************************************************************************************************
!> \brief Symmetrization of density matrix - transform to new k-point
!> \param smat density matrix at new kpoint
!> \param pmat reference density matrix
!> \param rot Rotation matrix
!> \param f0 Permutation of atoms under transformation
!> \param symmetric Symmetric matrix
!> \param antisymmetric Anti-Symmetric matrix
!> \par Notes
!>      Exactly one of symmetric/antisymmetric must be .TRUE.
!>      The identity operation copies pmat to smat. Atom permutations are only applied if the
!>      matrix distribution has a single process. All other cases (distributed matrices, or a
!>      rotation without atom permutation) abort: reduced k-point grids are not supported yet.
! **************************************************************************************************
   SUBROUTINE symtrans(smat, pmat, rot, f0, symmetric, antisymmetric)
      TYPE(dbcsr_type), POINTER                          :: smat, pmat
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(IN)         :: rot
      INTEGER, DIMENSION(:), INTENT(IN)                  :: f0
      LOGICAL, INTENT(IN), OPTIONAL                      :: symmetric, antisymmetric

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'symtrans'

      INTEGER                                            :: handle, ia, numnodes, pcol, prow, scol, &
                                                            srow
      LOGICAL                                            :: anti, found, has_perm, has_rot, is_sym, &
                                                            swapped
      REAL(KIND=dp)                                      :: diag_dev, tsign
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: pblock, sblock
      TYPE(dbcsr_distribution_type)                      :: dist
      TYPE(dbcsr_iterator_type)                          :: iter

      CALL timeset(routineN, handle)

      ! check symmetry options
      is_sym = .FALSE.
      IF (PRESENT(symmetric)) is_sym = symmetric
      anti = .FALSE.
      IF (PRESENT(antisymmetric)) anti = antisymmetric

      CPASSERT(.NOT. (is_sym .AND. anti))
      CPASSERT((is_sym .OR. anti))

      ! atoms are permuted unless f0 is the identity map
      has_perm = ANY(f0 /= [(ia, ia=1, SIZE(f0))])

      ! a real rotation unless rot is (numerically) the identity matrix
      diag_dev = ABS(rot(1, 1) - 1.0_dp) + ABS(rot(2, 2) - 1.0_dp) + ABS(rot(3, 3) - 1.0_dp)
      has_rot = (ABS(SUM(ABS(rot)) - 3.0_dp) > 1.e-12_dp) .OR. (ABS(diag_dev) > 1.e-12_dp)

      ! swapping row and column of an antisymmetric matrix block changes its sign
      tsign = 1.0_dp
      IF (anti) tsign = -1.0_dp

      IF (.NOT. (has_rot .OR. has_perm)) THEN
         ! this is the identity operation, just copy the matrix
         CALL dbcsr_copy(smat, pmat)
      ELSE IF (.NOT. has_perm) THEN
         ! no atom permutations, this is always local; rotations are ignored for now
         CALL dbcsr_set(smat, 0.0_dp)
         CALL dbcsr_copy(smat, pmat)
         CALL cp_abort(__LOCATION__, "k-points need FULL_GRID currently. "// &
                       "Reduced grids not yet working correctly")
      ELSE
         CALL dbcsr_set(smat, 0.0_dp)
         CALL dbcsr_get_info(pmat, distribution=dist)
         CALL dbcsr_distribution_get(dist, numnodes=numnodes)
         IF (numnodes /= 1) THEN
            ! distributed matrices, most general code needed
            CALL cp_abort(__LOCATION__, "k-points need FULL_GRID currently. "// &
                          "Reduced grids not yet working correctly")
         ELSE
            ! all blocks are local: move each block to its permuted position, keeping the
            ! row <= col storage (transposed, with tsign, if the permutation swaps the order)
            CALL dbcsr_iterator_start(iter, pmat)
            DO WHILE (dbcsr_iterator_blocks_left(iter))
               CALL dbcsr_iterator_next_block(iter, prow, pcol, pblock)
               srow = MIN(f0(prow), f0(pcol))
               scol = MAX(f0(prow), f0(pcol))
               swapped = f0(prow) > f0(pcol)
               CALL dbcsr_get_block_p(matrix=smat, row=srow, col=scol, BLOCK=sblock, found=found)
               IF (.NOT. found) CYCLE
               IF (swapped) THEN
                  sblock = tsign*TRANSPOSE(pblock)
               ELSE
                  sblock = pblock
               END IF
            END DO
            CALL dbcsr_iterator_stop(iter)
         END IF
      END IF

      CALL timestop(handle)

   END SUBROUTINE symtrans

! **************************************************************************************************

END MODULE kpoint_methods
